Knowledge-enhanced prototypical network with class cluster loss for few-shot relation classification

Few-shot relation classification identifies the relation between target entity pairs in unstructured natural language text by training on a small number of labeled samples. Recent prototypical network-based studies have focused on enhancing the prototype representation capability of models by incorporating external knowledge. However, most of these works constrain class prototype representations implicitly through complex network structures, such as multi-attention mechanisms, graph neural networks, and contrastive learning, which limits the model's ability to generalize. In addition, most models trained with triplet loss disregard intra-class compactness, which limits their ability to handle outlier samples with low semantic similarity. Therefore, this paper proposes a non-weighted prototype enhancement module that uses the feature-level similarity between prototypes and relation information as a gate to filter and complete features. Meanwhile, we design a class cluster loss that samples hard positive and negative samples and explicitly constrains both intra-class compactness and inter-class separability to learn a highly discriminative metric space. Extensive experiments on the publicly available datasets FewRel 1.0 and 2.0 demonstrate the effectiveness of the proposed model.


Introduction
Relation classification (RC) [1] is a fundamental information extraction task that aims to understand natural language text and determine the relation between entity pairs. For example, the sentence "Sofia is the capital city of Bulgaria." expresses the relation "capital of" between "Sofia" and "Bulgaria." Traditional relation classification methods are primarily trained on fully supervised data [2], so their performance depends on the amount of labeled data available. However, manual labeling is expensive, labor-intensive, and time-consuming, which makes it difficult to extend fully supervised models to specific domains. To reduce labeling costs, Mintz et al. [3] developed distant supervision, which automatically labels training samples in a corpus by aligning entities with a knowledge base. Unfortunately, this labeling strategy rests on too strong an assumption: sentences that contain the same entity pair do not necessarily express the relation recorded in the knowledge base, which introduces noise into the constructed dataset. In addition, because samples are unevenly distributed in the corpus, many relations in the long tail still lack sufficient training data, which leads to a significant decline in model performance. It is therefore necessary to study relation classification with insufficient training data. Humans, by contrast, need only a few examples to learn something new: even a young child can tell what makes each animal unique after looking at a few cartoon pictures of it. This observation inspired few-shot learning (FSL) [4], which aims to learn and solve problems from a few labeled examples. Early FSL research focused primarily on computer vision [5]; more recently, FSL has gradually expanded into natural language processing. Because of the diversity and complexity of natural language, few-shot text classification is more challenging than few-shot image classification.
Few-shot Relation Classification (FSRC) focuses on learning the generic knowledge that is built into different relation classes through multi-task learning. This allows the model to quickly process new tasks with a few instances so that it can generalize to real-world application scenarios. Table 1 shows a task example of FSRC. The learning objective of FSRC is to accurately predict the query instance's relation.
Research on FSRC has grown steadily in recent years. Sun et al. [6] propose a hierarchical attention prototypical network that guides the model to weigh the importance of each word in a given instance for relation classification. Ye et al. [7] propose a multi-level matching and aggregation network to improve class prototypes. To provide the model with reliable prediction cues, Yang et al. [8] introduce text descriptions into prototype representation. Han et al. [9] propose a pre-training method that introduces relation labels and relation descriptions from a large dataset during the pre-training phase to improve sentence representation.
However, most existing methods constrain the generation of class prototypes through complex network structures, which introduce numerous meaningless or even harmful parameters and typically exhibit low bias and high variance [10]. In addition, because triplets are sampled at random during training, models trained with triplet loss [11] are prone to local optima and erroneous decisions on challenging few-shot tasks whose relations are semantically similar. Although Xiao et al. [12] introduce hard sample mining into the triplet loss, it still considers only the relative distance between positive and negative samples [13], which limits the model's ability to handle outlier samples with low semantic similarity. To address these problems, this paper proposes a knowledge-enhanced prototypical network (KEPN), which improves the model's ability to represent prototypes and to handle difficult classification tasks. Specifically, we design a non-weighted prototype enhancement module that explicitly filters redundant information between basic prototypes and relation information and performs feature selection and complementation between them. Using relation label and description knowledge, the module adds supporting information to the prototype representation and reduces the risk of overfitting. We also design a new loss function that focuses on intra-class compactness to better handle outlier samples. The contributions of this paper are summarized as follows:
We propose a non-weighted prototype enhancement method that filters and complements the features of prototypes and relation information to improve prototype representation.
We design a class cluster loss function that uses hard positive and negative samples as optimization instances and explicitly constrains both intra-class compactness and inter-class separability to discover a more discriminative metric space.
Our model achieves competitive performance compared with baseline models on the FSRC dataset FewRel 1.0, and ablation studies and visualizations demonstrate the validity of our model.

Few-shot learning
FSL aims to train a model that can rapidly adapt to new tasks with a few examples. Existing approaches fall into optimization-based and metric-based methods. Optimization-based methods aim to find initialization parameters from which the model reaches the best prediction performance on subsequent tasks with as few gradient descent steps as possible. MAML [14] is a model-agnostic algorithm that iteratively learns how to update network parameters using second-order gradients. Reptile [15] is a first-order algorithm that trains the initialization parameters by implicitly maximizing the inner product between the gradients of different minibatches. Both MAML and Reptile focus on improving the overall learning capacity of the model rather than solving a specific problem, and many researchers have proposed improvements on this foundation. Ravi et al. [16] designed an LSTM-based meta-learner that simultaneously learns initialization parameters and optimization rules. Dong et al. [17] propose a parameter optimization algorithm guided by meta-information. Qu et al. [18] propose a Bayesian method to generalize various prototypes more effectively. Metric-based methods, which predict results from the distance between query and support samples, are simpler and more efficient than optimization-based methods. The Siamese network [19] obtains feature vectors for input samples from two networks with shared weights and computes their similarity using Euclidean distance. Matching networks [20] extract feature vectors with convolutional neural networks and measure distance by cosine similarity. The prototypical network [21] represents each class by a prototype aggregated from its support instances and classifies a query instance by computing its distance to each prototype. In recent years, few-shot learning has been applied in various fields. Zhang et al. [22] propose a soft distribution-aware few-shot learning strategy to segment tumors from magnetic resonance imaging data in low-resource scenarios. Feng et al. [23] propose a class-adaptive framework based on MAML to address few-shot anomaly detection in encrypted traffic. Mozafari et al. [24] combine MAML with a prototypical network and apply it successfully to few-shot cross-lingual hate speech detection. This paper is also based on the prototypical network, which reflects a simpler inductive bias.

Prototypical network
Currently, metric-based models for FSRC are primarily built on prototypical networks. To mitigate the effects of noisy data, Gao et al. [25] proposed a hybrid attention network comprising instance-level and feature-level attention. Yang et al. [26] proposed an entity concept enhancement method that combines concept and sentence information at the word level to provide effective relation classification cues. Dong et al. [27] modeled a generic relation network via semantic mapping. Prototypical networks have also been studied extensively in other fields. Liu et al. [28] proposed an interaction graph-based prototypical network to address domain transfer. Yarats et al. [29] combine prototypical representations with reinforcement learning to learn efficient representations. However, most of these models rely on heavily parameterized networks to constrain class prototype representations implicitly, which limits their ability to generalize to new tasks. We believe that representing class prototypes with fewer parameters and more explicit constraints is advantageous. We therefore explore a non-weighted prototype enhancement module that explicitly filters and fuses class prototypes and relation information to improve the generalizability of the model.

Triplet loss
In the metric space, prototypes with similar semantics are usually very close to each other, which makes them hard for the model to distinguish. Some works have therefore used triplet loss to constrain the distance between samples of different classes. Fan et al. [11] use triplet loss to enforce margin distances between different prototypes, allowing the model to learn a highly discriminative metric space. However, the effectiveness of triplet loss depends heavily on the sampling strategy. Random sampling causes the model to overlook samples with high loss, which slows convergence in the later stages of training and increases the tendency to fall into local optima. Xiao et al. [12] introduced hard sample mining into triplet loss, selecting the most distant positive sample and the closest negative sample as optimization instances, which helps the model handle relations with high semantic similarity. Unfortunately, this approach considers only the relative distance, so the model remains limited in its ability to deal with outlier samples. This paper investigates how to constrain intra-class compactness in order to learn a highly discriminative metric space.

Task definition
Most FSRC models are currently trained with meta-learning [30], which consists of two phases: $M_{train}$ and $M_{test}$. In the $M_{train}$ phase, N classes are sampled at random from the dataset and K instances are selected from each class to form the support set of a task; many such tasks are used to train the model. In addition, one instance is selected from the remaining samples of each of the N classes to form the query set, so each query set contains N instances. This configuration is commonly referred to as N−way−K−shot. $M_{test}$ uses the same task configuration as $M_{train}$, but the model is applied to target-domain classes that do not overlap with the classes in $M_{train}$.
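The episode construction described above can be sketched in a few lines of Python. This is a minimal illustration, assuming the data is a plain dictionary from relation names to lists of sentences; the function name `sample_episode` and the toy data are hypothetical and not part of the authors' code.

```python
import random

def sample_episode(data, n_way, k_shot, rng=random):
    """Sample one N-way-K-shot episode (hypothetical helper).

    data maps each relation label to a list of instances. Each episode has
    K support instances per class and one query instance per class, drawn
    from the remaining samples of that class (no overlap with the support).
    """
    relations = rng.sample(sorted(data), n_way)          # choose N relation classes
    support, query = [], []
    for label_id, rel in enumerate(relations):
        picked = rng.sample(data[rel], k_shot + 1)       # K support + 1 query, distinct
        support.append([(x, label_id) for x in picked[:k_shot]])
        query.append((picked[k_shot], label_id))
    return relations, support, query

# Toy usage: 10 relations with 20 placeholder sentences each, 5-way-1-shot.
toy = {f"rel{i}": [f"rel{i}_sent{j}" for j in range(20)] for i in range(10)}
rels, support, query = sample_episode(toy, n_way=5, k_shot=1, rng=random.Random(0))
```

Labels are re-indexed from 0 to N-1 inside each episode, because the classifier only has to choose among the N sampled classes.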
Consequently, FSRC can be defined as the task of learning prototype representations from a given support set $S = \{s_n^k \mid n = 1,\dots,N;\ k = 1,\dots,K\}$ and predicting the relation $y_n$ corresponding to each query instance $q_n$ in the query set $Q = \{q_n \mid n = 1,\dots,N\}$.

Methodology
In this section, we describe KEPN with class cluster loss in detail. As depicted in Fig 1, the input to KEPN consists of multiple tasks sampled from a dataset, each containing a support set and a query set. Relation labels and relation descriptions are also fed into the encoder. All inputs are encoded by a sentence encoder to generate prototypes and relation information vectors, which are then passed to the prototype enhancement module to enhance the prototype representation. Finally, a class cluster loss function encourages the model to learn a highly discriminative metric space.

Sentence encoder
We use a single BERT encoder to embed both support and query instances. Following MTB [31], each instance is represented by concatenating the encoder outputs at the start tokens of the two target entities. The support and query instances are denoted as $\{S_n^k \in \mathbb{R}^{2d} \mid n = 1,\dots,N;\ k = 1,\dots,K\}$ and $\{Q_n \in \mathbb{R}^{2d} \mid n = 1,\dots,N\}$, respectively, where d is the hidden size of the sentence encoder. For each relation class, we concatenate the relation label and its description and feed them into the sentence encoder to obtain two relation feature vectors: the output vector of the "[CLS]" token serves as the global representation of the relation $R_{glo}$, and the mean of all word feature vectors serves as the local representation of the relation $R_{loc}$.
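The pooling steps above can be sketched with NumPy, assuming the encoder's last-layer hidden states are already available as a (sequence length × d) array. Here random arrays stand in for BERT outputs, and the entity start positions are hypothetical; whether special tokens such as [CLS] are included in the mean for $R_{loc}$ is also an assumption (the optional `mask` lets a caller exclude them).

```python
import numpy as np

def instance_vector(hidden, e1_start, e2_start):
    """Concatenate hidden states at the two entity start tokens -> vector in R^{2d}."""
    return np.concatenate([hidden[e1_start], hidden[e2_start]])

def relation_vectors(hidden, mask=None):
    """Return (R_glo, R_loc): the [CLS] vector (position 0) and the mean token vector."""
    r_glo = hidden[0]
    if mask is None:
        mask = np.ones(hidden.shape[0], dtype=bool)   # average over all tokens by default
    r_loc = hidden[mask].mean(axis=0)
    return r_glo, r_loc

rng = np.random.default_rng(0)
d = 8
sent_hidden = rng.normal(size=(12, d))   # stand-in for BERT output of one instance
x = instance_vector(sent_hidden, e1_start=2, e2_start=7)
rel_hidden = rng.normal(size=(10, d))    # stand-in for BERT output of "label + description"
r_glo, r_loc = relation_vectors(rel_hidden)
```

In a real setup, `hidden` would come from the encoder's last hidden layer, and the entity start positions from marker tokens inserted into the sentence.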

Prototype enhancement module
Prototypical networks classify a query instance by computing its distance in the metric space to each class prototype. Prototypes are typically obtained by averaging the support instances, which provides no reliable prior knowledge. As depicted in Fig 2, external knowledge such as relation labels and descriptions can provide strong supporting evidence for relation classification, and this information is readily available in the FSRC task. Consequently, this paper uses relation labels and descriptions as supplementary data to enhance the prototype representation.

Basic prototypes.
Following the standard configuration of prototypical networks, the sentence encoder encodes the K instances of each of the N classes. For each relation, the mean of its K encoded instances is used as the basic prototype $P_{bas}$:
$$P^n_{bas} = \frac{1}{K}\sum_{k=1}^{K} f_\theta(s_n^k), \quad n = 1,\dots,N,$$
where $f_\theta(\cdot)$ represents the sentence encoder.
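A minimal NumPy sketch of the basic prototype and the nearest-prototype decision follows. It assumes the support instances are already encoded into an (N, K, 2d) array. The logits use negative squared Euclidean distance, as in the original prototypical network; this choice of metric for the classifier is an assumption here and may differ from the paper's implementation.

```python
import numpy as np

def basic_prototypes(support):
    """support: (N, K, dim) encoded support instances -> (N, dim) class means."""
    return support.mean(axis=1)

def classify(queries, prototypes):
    """Assign each query to the prototype with the smallest Euclidean distance.

    Returns predicted class indices and logits (negative squared distances);
    a softmax over the logits gives class probabilities.
    """
    diff = queries[:, None, :] - prototypes[None, :, :]   # (num_queries, N, dim)
    sq_dist = (diff ** 2).sum(axis=-1)                     # (num_queries, N)
    logits = -sq_dist
    return logits.argmax(axis=1), logits

# Toy 3-way-2-shot episode in a 4-dimensional embedding space.
centers = np.array([[0.0, 0.0, 0.0, 0.0], [5.0, 5.0, 5.0, 5.0], [-5.0, 5.0, -5.0, 5.0]])
offsets = np.array([[0.1, 0.0, 0.0, 0.0], [-0.1, 0.0, 0.0, 0.0]])
support = centers[:, None, :] + offsets[None, :, :]       # (N=3, K=2, dim=4)
protos = basic_prototypes(support)
queries = centers[[2, 0, 1]] + 0.2                         # one query near each class
pred, logits = classify(queries, protos)
```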

PLOS ONE
Knowledge-enhanced prototypical network for few-shot relation classification

Enhanced prototypes.
To avoid the overfitting caused by an excessive number of parameters, we propose a non-weighted prototype enhancement module that combines the basic prototype with relation information, as depicted in Fig 3. First, the global and local representations of the relation are concatenated to form the final representation $R_{rep}$ of the relation information:
$$R_{rep} = [R_{glo}; R_{loc}],$$
where $[\cdot;\cdot]$ denotes concatenation. The feature-level similarity between the relation information and the basic prototype is then computed by element-wise division, and the hyperbolic tangent function maps this similarity to the update signal of the gate G.
Finally, the basic prototype and the relation information are fused via the gate to generate the enhanced prototype representation $P_{enh}$: redundant features with a high degree of similarity are filtered out, and features with a low degree of similarity are combined through direct addition. Here $n = 1,\dots,N$ and $P^n_{enh} \in \mathbb{R}^{2d}$; $P_{bas}$, $R_{rep}$, G, and $P_{enh}$ all lie in $\mathbb{R}^{2d}$.
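The parameter-free gate can be illustrated with the NumPy sketch below. The fusion equation itself is not reproduced in this text, so the final combination rule here is a hypothetical stand-in: relation features are scaled by (1 − G) and added to the basic prototype. The concatenation, element-wise division, and tanh follow the description above; the small `eps` guard against division by zero is our addition.

```python
import numpy as np

def enhance_prototype(p_bas, r_glo, r_loc, eps=1e-6):
    """Parameter-free gated fusion (illustrative sketch, fusion rule assumed)."""
    r_rep = np.concatenate([r_glo, r_loc], axis=-1)      # R_rep = [R_glo; R_loc] in R^{2d}
    # Element-wise ratio between prototype and relation features, squashed by tanh
    # to form the gate G. eps replaces near-zero denominators (our addition).
    safe = np.where(np.abs(r_rep) < eps, eps, r_rep)
    gate = np.tanh(p_bas / safe)
    # ASSUMED fusion: relation features are scaled by (1 - G) and added, so
    # dimensions where the gate approaches 1 contribute almost nothing.
    return p_bas + (1.0 - gate) * r_rep

p_bas = np.array([1.0, 2.0, -1.0, 0.5])   # toy basic prototype (2d = 4)
r_glo = np.array([1.0, 4.0])              # toy R_glo (d = 2)
r_loc = np.array([-1.0, 0.0])             # toy R_loc
p_enh = enhance_prototype(p_bas, r_glo, r_loc)
```

The module introduces no trainable weights: the gate is computed entirely from the two inputs, which is the property the text relies on to limit overfitting.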

Class cluster loss
The performance of an FSRC model depends heavily on the distribution of instance vectors in the metric space. To learn a highly discriminative metric space, we design a class cluster loss. As illustrated in Fig 4, its primary objective is to constrain both intra-class compactness and inter-class separability.

Triplet loss.
Our work builds on the triplet loss, whose purpose is to separate samples of different classes. It can be defined as
$$L_{tri} = \max\big(D(a, p) - D(a, n) + margin,\ 0\big),$$
where a, p, and n represent the anchor, positive, and negative samples, respectively, margin is a hyperparameter that enforces a minimum distance gap between dissimilar samples, and D denotes the Euclidean distance between two samples.
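The triplet loss above translates directly into code. This sketch uses NumPy for the Euclidean distance; the margin value 0.5 is only an example default, not necessarily the setting in Table 2.

```python
import numpy as np

def euclidean(a, b):
    """Euclidean distance D between two vectors."""
    return float(np.linalg.norm(a - b))

def triplet_loss(anchor, positive, negative, margin=0.5):
    """max(D(a, p) - D(a, n) + margin, 0)."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)

a = np.array([0.0, 0.0])
p = np.array([3.0, 4.0])   # D(a, p) = 5
n = np.array([6.0, 8.0])   # D(a, n) = 10
loss_easy = triplet_loss(a, p, n)   # negative already far enough -> zero loss
loss_hard = triplet_loss(a, n, p)   # roles swapped -> 10 - 5 + 0.5
```

The loss is zero as soon as the negative is at least `margin` farther from the anchor than the positive, which is why easy, randomly sampled triplets contribute no gradient.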

Hard sample mining.
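Random sampling can make training unstable and slow to converge, so we adopt hard sample mining. Following Xiao et al. [12], we choose the farthest positive sample and the closest negative sample as optimization objects. Unlike their approach, we use the enhanced prototype as the anchor, which makes the data distribution more stable. For the anchor $P^i_{enh}$ of class i, the hard triplet loss can be defined as
$$L_{hard} = \max\Big(\max_{p \in \mathcal{P}_i} D(P^i_{enh}, p) - \min_{n \in \mathcal{N}_i} D(P^i_{enh}, n) + margin,\ 0\Big),$$
where $\mathcal{P}_i$ and $\mathcal{N}_i$ denote the samples belonging to class i and to the other classes, respectively.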
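Hard mining with prototypes as anchors can be sketched as follows. For each class, the sketch picks the farthest same-class sample and the nearest other-class sample. Averaging the per-class losses into one scalar is an assumption, since the text does not specify how they are aggregated.

```python
import numpy as np

def hard_triplet_loss(prototypes, embeddings, labels, margin=0.5):
    """Batch-hard triplet loss with class prototypes as anchors (sketch).

    prototypes: (N, dim); embeddings: (M, dim); labels: (M,) class indices.
    """
    # Pairwise Euclidean distances between every prototype and every sample.
    dists = np.linalg.norm(prototypes[:, None, :] - embeddings[None, :, :], axis=-1)
    losses = []
    for i in range(len(prototypes)):
        same = labels == i
        hardest_pos = dists[i, same].max()    # farthest positive
        hardest_neg = dists[i, ~same].min()   # closest negative
        losses.append(max(hardest_pos - hardest_neg + margin, 0.0))
    return float(np.mean(losses))             # averaging over classes is assumed

protos = np.array([[0.0, 0.0], [10.0, 0.0]])
emb = np.array([[1.0, 0.0], [0.0, 5.0], [9.0, 0.0], [4.0, 0.0]])
labels = np.array([0, 0, 1, 1])
loss = hard_triplet_loss(protos, emb, labels)
```

In this toy case only class 0 is penalized: its outlier at distance 5 is farther than the class-1 sample at distance 4.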

Class cluster loss.
Triplet loss only considers the margin of inter-class separability, which limits the model's ability to handle outlier samples. We therefore design a class cluster loss based on the triplet loss. First, the central distance c is computed as the mean distance between the samples of a class and their prototype. Then, to make the metric space more discriminative, we impose explicit constraints on both intra-class compactness and inter-class separability, which together form the class cluster loss. Finally, to find a better classification boundary, we train the model jointly with the cross-entropy loss $-\log z_y$ and the class cluster loss, where $z_y$ represents the probability that the query instance belongs to relation class y, and α is a hyperparameter that balances the two losses.
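The exact class cluster loss and joint-loss equations are not reproduced in this text, so the sketch below shows only one hypothetical instantiation consistent with the description. It computes the central distance c per class. An intra-class term penalizes how far the hardest positive lies beyond c. An inter-class term is a hard triplet margin. The joint loss is assumed to be $L_{CE} + \alpha L_{cc}$. None of these forms should be read as the authors' formulas.

```python
import numpy as np

def class_cluster_loss(prototypes, embeddings, labels, margin=0.5):
    """Hypothetical class-cluster-style loss (not the paper's exact equation)."""
    dists = np.linalg.norm(prototypes[:, None, :] - embeddings[None, :, :], axis=-1)
    losses = []
    for i in range(len(prototypes)):
        same = labels == i
        c = dists[i, same].mean()                    # central distance of class i
        d_pos = dists[i, same].max()                 # hardest positive
        d_neg = dists[i, ~same].min()                # hardest negative
        intra = max(d_pos - c, 0.0)                  # gap between farthest positive and c
        inter = max(d_pos - d_neg + margin, 0.0)     # hard triplet separation term
        losses.append(intra + inter)
    return float(np.mean(losses))

def cross_entropy(logits, target):
    """-log z_y with a numerically stable log-softmax."""
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return float(-log_probs[target])

def joint_loss(logits, target, cc_loss, alpha=0.5):
    """ASSUMED combination: cross-entropy + alpha * class cluster loss."""
    return cross_entropy(logits, target) + alpha * cc_loss

protos = np.array([[0.0, 0.0], [10.0, 0.0]])
emb = np.array([[1.0, 0.0], [0.0, 5.0], [9.0, 0.0], [4.0, 0.0]])
labels = np.array([0, 0, 1, 1])
cc = class_cluster_loss(protos, emb, labels)
total = joint_loss(np.array([0.0, 0.0]), 0, cc, alpha=0.5)
```

Unlike the plain triplet term, the intra-class term is non-zero whenever a class has a sample far from its typical distance to the prototype. This is the absolute compactness signal the text argues triplet loss lacks.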

Experiments
In this section, we present model comparison experiments, ablation studies, and visualizations on the public FSRC datasets FewRel 1.0 [32] and FewRel 2.0 [33].

Dataset.
We evaluate our model on FewRel 1.0 and FewRel 2.0. FewRel 1.0 is a general-domain dataset commonly used for few-shot relation classification. It was built by using Wikidata as a knowledge base together with distant supervision to identify sentences containing the target relations in a Wikipedia corpus. The final training, validation, and test sets contain 64, 16, and 20 relations, respectively, with 700 instances per relation. Each instance contains 24.99 tokens on average, and the dataset contains 124,577 unique tokens in total. FewRel 2.0 is a more challenging dataset designed to assess the domain adaptation and "none-of-the-above" detection abilities of FSRC models. This dataset consists of specialized biomedical data covering 25 relations with 100 instances each. In this paper, we use FewRel 2.0 to evaluate the domain generalization ability of the model.

Evaluation.
Following the FewRel settings, our model is trained in the N−way−K−shot setting, where N and K denote the number of relation classes and the number of instances per class, respectively. Accuracy is used to evaluate model performance:
$$Accuracy = \frac{Y_{true}}{Y},$$
where $Y_{true}$ and Y denote the number of correctly classified instances and the total number of instances to be classified, respectively.

Comparable models
We compare our model with twelve other models: two CNN-based and ten BERT-based models. The following models use a CNN as the sentence encoder: 1) Proto-HATT [25], a hybrid attention prototypical network that focuses on the noise problem; 2) MLMAN [7], an interactive prototypical network built on multi-level matching and aggregation. The following models use BERT as the sentence encoder: 3) BERT-PAIR [33], which pairs each support instance with a query instance and feeds the pair into BERT for relation prediction; 4) Proto-BERT [21], which classifies relations based on the distance between the query instance and each class prototype; 5) REGRAB [18], a Bayesian meta-learning method that generalizes various prototypes more effectively; 6) TD-Proto [8], which uses a weighted gate mechanism to combine entity and relation description information into a knowledge-aware class prototype; 7) CTEG [34], a confusion-aware training method that uses the Kullback-Leibler divergence to improve the ability to distinguish true relations from confusing relations; 8) ConceptFERE [26], which introduces entity concept information as strong supporting evidence for relation classification; 9) HCRP [9], which introduces relation labels into contrastive learning of prototype representations and focuses the model on difficult tasks by increasing the weight of difficult samples; 10) MTB [31], a pre-training model that learns relation representations from a large unsupervised corpus of entity-linked text based on Harris' distributional hypothesis; 11) CP [35], a contrastive pre-training model; 12) MapRE [27], a relation mapping network that takes advantage of label knowledge.

Experiment setting
All hyperparameters are listed in Table 2. Our experimental environment uses Transformers 4.7.0 and PyTorch 1.7.1, and training is performed on an RTX 3090Ti GPU. The pre-trained models BERT-base-uncased and CP are each used as the sentence encoder of our model, where CP is a contrastive pre-training model based on BERT. The pre-trained weights of the sentence encoder serve as initialization parameters for fine-tuning on the FewRel dataset. We use the AdamW optimizer with learning rates of 1e-5 and 5e-6. The hidden size is set to 768 and the batch size to 4. The numbers of training and validation iterations are 20,000 and 4, respectively.
We compare KEPN with the CNN-based and BERT-based methods in Table 3 under four conventional N−way−K−shot settings. Notably, the BERT-based models in the upper half of the table use the original BERT, whereas those in the lower half perform additional pre-training on BERT. Proto-BERT is the basic model, which includes neither our prototype enhancement module nor the class cluster loss. In this experiment, we apply our model to both BERT and CP. Table 3 shows three results. First, KEPN achieves the highest accuracy on the test set. Second, our model obtains larger improvements on the 5−way−1−shot and 10−way−1−shot tasks, which suggests that KEPN is particularly well suited to tasks with very few samples. Third, KEPN improves substantially over the basic model. These results demonstrate the effectiveness of our approach.

Result and discussion
Our model achieves the best performance in the general domain. To further verify its transferability, we test it on the biomedical dataset FewRel 2.0, as shown in Table 4. To evaluate domain generalization, we train only on FewRel 1.0 and use FewRel 2.0 only for testing. Table 4 shows two results. First, all models perform worse than in the general domain, to varying degrees. Second, KEPN outperforms the other models on all tasks. These results demonstrate the generalization ability and effectiveness of our model and highlight the importance of enhanced prototype representations for FSRC tasks.

Performance on different tasks.
To verify the applicability of KEPN, we compare it with other models under different task settings on the FewRel 1.0 validation set. First, we vary K from 1 to 10 with N fixed at 5. Then, we vary N from 3 to 12 with K fixed at 1. As shown in Fig 5, performance improves as the number of samples increases, indicating that additional instances add useful supporting clues to the prototype representation; this is consistent with our idea of using relation information to improve class prototype representations. Accuracy degrades as the number of classes increases, which is expected because the model must account for more relational differences.
Compared with the basic model, our model achieves higher accuracy, which supports its robustness and generalizability.

Statistical test.
To examine whether the difference between the proposed method and the baseline model is significant, we conducted a paired-sample t-test [36]. Specifically, we initialized the two models with identical random seeds and trained and validated them on the four N−way−K−shot tasks to obtain paired results, with 1,000 training and validation iterations. This process was repeated seven times with different random seeds, yielding seven pairs of results. As shown in Table 5, the two-tailed P-values were below 0.05, indicating that our model significantly outperforms the baseline model.
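The paired-sample t statistic can be computed with the Python standard library alone, as sketched below. The accuracy lists are hypothetical placeholders, not the values in Table 5. With seven pairs (6 degrees of freedom), the two-tailed 0.05 critical value is about 2.447. If SciPy is available, `scipy.stats.ttest_rel(a, b)` returns both the statistic and the two-tailed P-value.

```python
import math
import statistics

def paired_t_statistic(a, b):
    """t statistic of a paired-sample t-test on matched runs a[i], b[i]."""
    diffs = [x - y for x, y in zip(a, b)]
    n = len(diffs)
    return statistics.mean(diffs) / (statistics.stdev(diffs) / math.sqrt(n))

# Hypothetical accuracies from 7 seeds (placeholder numbers, not the paper's results).
ours     = [0.86, 0.87, 0.85, 0.88, 0.86, 0.87, 0.86]
baseline = [0.84, 0.85, 0.84, 0.85, 0.85, 0.84, 0.85]
t = paired_t_statistic(ours, baseline)
significant = abs(t) > 2.447   # two-tailed critical value for df = 6, alpha = 0.05
```

Pairing runs that share a seed removes the seed-to-seed variance common to both models, so the test only has to detect the consistent per-seed difference.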

Ablation study
Table 6 examines the impact of relation information, fusion mechanisms, and loss functions on the performance of the prototypical network. KEPN refers to the full model with the non-weighted prototype enhancement module and class cluster loss. The groups "without relation info." and "with relation info." denote, respectively, the prototypical network without relation information and various alternatives to the non-weighted fusion mechanism for relation information. "w/ loss" denotes experiments that replace the class cluster loss with alternative loss functions. Table 6 shows three results. First, relation information is essential for prototype representation, and concatenating the global and local relation features improves performance the most. Second, direct addition, concatenation, the weighted gate, and hybrid attention all perform worse at feature fusion. One possible explanation is that these methods introduce too much redundant information or too many features, leading the model to overfit. Third, owing to the explicit constraints on both intra-class compactness and inter-class separability, our model outperforms variants trained with cross-entropy loss and triplet loss. These results indicate that KEPN represents prototypes and the metric space more effectively.

Visualization
To gain a more intuitive understanding of KEPN, we visualize a sample of instance vectors. As depicted in Fig 6, the vanilla model, labeled ProtoNet, has a limited capacity to represent instance vectors. After the triplet loss is introduced, the model named ProtoTriplet can distinguish between instances of different relation classes. Nevertheless, some outlier samples remain hard to classify, and the intra-class distribution is not compact enough. KEPN, with the prototype enhancement module and class cluster loss, learns a metric space with greater discriminability than the other models, enabling it to predict relations between entities more precisely. These observations support our view that few parameters and explicit constraints are advantageous for few-shot models.

Case study
We present the classification results of several instances from the FewRel 1.0 validation set. Table 7 lists five instances that Proto-BERT misclassifies but our model classifies correctly. In the third instance of Table 7, Proto-BERT, which lacks prototype enhancement and class cluster loss, incorrectly classifies the relation as "part of," whereas our model correctly predicts "member of." This example is challenging because the semantics of "part of" and "member of" are very similar. Table 7 suggests that KEPN is capable of capturing subtle distinctions between relations.

Hyperparameter study
In this section, we examine how the values of the hyperparameters margin and alpha affect model performance. BERT was used as the pre-trained model, with 3,000 training and 1,000 validation iterations. Fig 7 shows the effects of different alpha and margin values on the 5-way-5-shot and 10-way-5-shot tasks. First, with the margin fixed at 0.5, alpha was increased from 0.1 to 1. Then, with alpha fixed at 0.5, the margin was varied from 0.1 to 1. Performance varies with both parameters, indicating that both alpha and margin affect the model. The best performance is achieved when margin and alpha are close to 0.4. However, in terms of accuracy, the model is not very sensitive to the values of alpha and margin. One possible reason is that most tasks are relatively simple, so the cross-entropy loss plays the dominant role in the early stages of training.

Conclusions and future work
This paper focuses on FSRC and proposes a knowledge-enhanced prototypical network. The central idea of our model is to use the feature-level similarity between the prototype and relation information to filter and fuse information through a non-weighted gate mechanism, which enhances the prototype representation while avoiding the overfitting caused by too many parameters. In addition, we design a class cluster loss that optimizes the positive and negative samples that are difficult to classify during training and explicitly constrains both intra-class compactness and inter-class separability to learn a more discriminative metric space. As a result, KEPN attains the best accuracy on FewRel 1.0. We believe that fewer parameters and explicit constraints are meaningful and can be generalized to other few-shot classification tasks. However, several limitations remain. First, the model's performance is limited when it is applied to more complex datasets. Second, the model depends on a fixed set of relation classes during learning, which is a disadvantage for continual learning. In future research, we will focus on developing more flexible and adaptable methods to handle a wider range of relation classification tasks, such as those in the biological or civil domains. Additionally, incorporating incremental learning into few-shot classification may be necessary to handle the continuous addition of new relation classes in real-world scenarios.